iT邦幫忙

2021 iThome 鐵人賽

DAY 25
0
AI & Data

手寫中文字之影像辨識系列 第 25

【第25天】部署API服務-Python Flask

  • 分享至 

  • xImage
  •  

摘要

  1. 導入套件
  2. 模型初始化資料
  3. API初始化
  4. server_uuid
  5. 轉換圖片格式
  6. 模型辨識手寫中文字
  7. 檢查預測結果是否為字串
  8. API服務
  9. 啟用API服務

內容

  1. 導入套件

    import base64
    import datetime
    import hashlib
    import time
    from argparse import ArgumentParser
    import multiprocessing
    import cv2
    import numpy as np
    from flask import Flask
    from flask import jsonify
    from flask import request
    from img_gray import process_img
    from PIL import Image
    import torch
    from torch import nn
    from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
    import torchvision.models as models
    import pandas as pd
    from R_model_load import Model
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.models import load_model
    import numpy as np
    import os
    from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
    
  2. 模型初始化資料

    2.1 資料內容

    • 隊長Email
    • uuid加密
    • CPU運算:因GCP免費試用,開啟的VM無GPU,故以CPU運算。
    • 4個Model初始化:3個影像辨識模型+1個SVM模型(用以進一步判斷isnull)。
    • 接收圖片的Log與官方800字清單
    • 模型組合之權重與閾值表

    2.2 程式碼

    app = Flask(__name__)
    
    # 隊長email
    CAPTAIN_EMAIL = 'XXXXXX@gmail.com'
    
    # uuid加密
    SALT = '1688'
    
    # CPU運算(關閉GPU)
    os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    
    # 4個Model初始化
    # Xception
    model_Xception = None
    # InceptionResNetV2
    model_V2 = None
    # Densenet201
    model_swa = None
    # R_SVM_model
    model_R = None
    
    # 接收的圖片Log檔
    file1 = open('./pic_base64.txt', 'a')
    # 官方800字清單
    words_path = r'./800_words.txt'
    file2 = open(words_path, 'rt', encoding='Big5')
    labels = list(file2.read())
    
    # 模型組合之權重與閾值表
    # 載入表格
    weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
    # DenseNet權重
    weight_swa = weight_df['wei_ex6'].values
    # InceptionResNetV2權重
    weight_V2 = weight_df['wei_ex5'].values
    # Xception權重
    weight_Xception = weight_df['wei_3'].values
    
  3. API初始化

    3.1 before_first_request:在處理第一個request前,先執行API初始化,用以載入模型。

    3.2 程式碼

    @app.before_first_request
    def init():
        # Xception
        global model_Xception
        model_Xception = load_model('./Xception_retrained_v2.h5')
    
        # InceptionResNetV2
        global model_V2
        model_V2 = load_model('./InceptionResNetV2.h5')
    
        # DenseNet201
        global model_swa
        model_densenet = models.densenet201(num_classes=800)
        model_path = './swa_densenet201.pth'
        model_fang = model_densenet
        model_swa = AveragedModel(model_fang)
        model_swa.eval()
        model_swa.load_state_dict(torch.load(model_path,
                                  map_location=torch.device('cpu')))
    
        # SVM模型
        global model_R
        MODEL_PATH = "./model_svm_v3"
        model_R = Model().load(MODEL_PATH)
        print('====================API初始化完成init====================')
    
  4. 產出server_uuid

    def generate_server_uuid(input_string):
        s = hashlib.sha256()
        data = (input_string + SALT).encode("utf-8")
        s.update(data)
        server_uuid = s.hexdigest()
        return server_uuid
    
  5. 檢查預測結果是否為字串:供後續輸出預測結果之前,判定資料型態。

    def _check_datatype_to_string(prediction):
        if isinstance(prediction, str):
            return True
        raise TypeError('Prediction is not in string type.')
    
  6. 將接收到的圖片轉換格式

    6.1 流程

    • 將base64編碼轉換成numpy格式。
    • 圖片預處理:將圖片轉換成灰階。
    • 將接收的圖片,儲存到Log檔:紀錄比賽圖片樣本,供後續改善模型之用。
    • 將圖片轉換成模型input格式

    6.2 程式碼

    def base64_to_binary_for_cv2(image_64_encoded):
        # base64轉numpy
        img_base64_binary = image_64_encoded.encode("utf-8")
        img_binary = base64.b64decode(img_base64_binary)
        image = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
                             cv2.IMREAD_COLOR)
    
        # 圖片預處理
        image = process_img(image)
        image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB))
        image_for_tensorflow = np.asarray(image)
    
        # 將接收的圖片,儲存到Log檔
        file1.write(image_64_encoded + '\n')
    
        # Xception之input圖片格式
        image_for_Xception = cv2.resize(image_for_tensorflow, (80,80),
                                        interpolation=cv2.INTER_CUBIC)
        image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
        image_for_Xception = image_for_Xception / 255
    
        # InceptionResNetV2之input圖片格式
        image_for_V2 = cv2.resize(image_for_tensorflow, (150 , 150),
                                  interpolation=cv2.INTER_CUBIC)
        image_for_V2 = np.expand_dims(image_for_V2, axis=0)
        image_for_V2 = image_for_V2 / 255
    
        # DenseNet201之input圖片格式
        transforms = Compose([ColorJitter(brightness=(1.5, 1.5),
                              contrast=(6, 6), saturation=(1, 1),
                              hue=(-0.1, 0.1)), ToTensor(),
                              Normalize((0.5,), (0.5,))])
        image_for_swa = image.resize((80, 80), Image.ANTIALIAS)
        image_for_swa = transforms(image_for_swa)
    
        return image_for_Xception,image_for_V2,image_for_swa
    
  7. 模型辨識手寫中文字

    7.1 流程

    • 計算3個模型之800字機率,並乘以加權分數。
    • 將800字的加權機率進行加總,取得新的800字機率。
    • 從新的800字機率中,取機率值最大的那個字,做為預測結果。
    • 以閾值判斷,該字是否屬於800字內。若機率大於閾值,輸出該字;反之,則輸出isnull。
    • 檢查預測結果是否為字串。

    7.2 程式碼

    def predict(image_for_Xception,image_for_V2,image_for_swa):
        # InceptionResNetV2 predict的機率加權
        # 機率向量
        pred_V2 = model_V2.predict(image_for_V2)[0]
        # 乘上權重的新機率向量
        new_V2_prob = pred_V2 * weight_V2
    
        # Xception predict的機率加權
        # 機率向量
        pred_Xception = model_Xception.predict(image_for_Xception)[0] 
        # 乘上權重的新機率向量
        new_Xception_prob = pred_Xception * weight_Xception 
    
        # DenseNet201 predict的機率加權
        img = image_for_swa.view(1, 3, 80, 80)
        output = model_swa(img)
        output = output.view(-1, 800)
        output_prob = nn.functional.softmax(output, dim=1)
        # 機率向量
        output_prob_np = output_prob.cpu().detach().numpy()[0]
        # 乘上權重取得新機率向量
        new_swa_prob = output_prob_np * weight_swa 
    
        # 三個模型向量相加取得新的向量,判定手寫中文字
        new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
        max_prob = np.max(new_prob)
        pred_word = np.argmax(new_prob)
    
        # 讀取該手寫中文字的閾值
        judge = labels[pred_word]
        mean_prob = weight_df[weight_df["word"] == judge]["mean_prob_new"].values
    
        # 判斷閾值
        if max_prob < mean_prob:
            prediction = "isnull"
        else:
            # 考慮加上SVM模型
            new_prob_2dim = new_prob[np.newaxis,:]
            # 丟入Rmodel預測是否為isnull,1為800字內,2為isnull
            pred = model_R.predict(new_prob_2dim)
            if pred == 2:
                prediction = "isnull"
            else:
                final_answer = np.argmax(new_prob)
                prediction = labels[final_answer]
    
        # 檢查預測結果是否為字串
        if _check_datatype_to_string(prediction):
            return prediction
    
  8. API服務(inference 資料傳輸格式:json)

    8.1 流程

    • 接收API用戶之request。
    • 取出json中image,並轉換成圖片格式。
    • 產出server_uuid:做為回傳時json內容之一。
    • 記錄錯誤log:供後續檢查API服務error之用。
    • 回傳預測結果給用戶

    8.2 程式碼

    @app.route('/inference', methods=['POST'])
    def inference():
        # 接收用戶request
        data = request.get_json(force=True)
    
        # 取image base64 encoded,並以cv2轉換格式
        image_64_encoded = data['image']
        image_for_Xception,image_for_V2,image_for_swa = base64_to_binary_for_cv2(image_64_encoded)
    
        # 產出server_uuid
        t = datetime.datetime.now()
        ts = str(int(t.utcnow().timestamp()))
        server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)
    
        # 記錄API錯誤log
        try:
            answer = predict(image_for_Xception,
                             image_for_V2,
                             image_for_swa)
        except TypeError as type_error:
            raise type_error
        except Exception as e:
            raise e
        server_timestamp = time.time()
    
        # 回傳預測結果給用戶
        return jsonify({'esun_uuid': data['esun_uuid'],
                        'server_uuid': server_uuid,
                        'answer': answer,
                        'server_timestamp': server_timestamp})
    
    if __name__ == "__main__":
    
        arg_parser = ArgumentParser(usage='Usage: python ' + __file__ +
                                   ' [--port <port>] [--help]')
        arg_parser.add_argument('-p', '--port', default=8080, help='port')
        arg_parser.add_argument('-d', '--debug', default=True, help='debug')
        options = arg_parser.parse_args()
    
        app.run(host='0.0.0.0', port=options.port, debug=options.debug)
    
  9. 啟用API服務

    9.1 如何啟用API服務

    • 到GCP啟用VM
    • 開啟cmd與ssh連線
    • 前往目標資料夾
    • 輸入指令:python3 api_1.py

    9.2 成功啟用API服務(如下圖)


小結

  1. 完成最後一關「部署API服務」後,玉山競賽流程就到此暫時告一段落。
  2. 鑑於此次競賽仍有進步空間,後續幾天將透過實作,和大家分享一些可能提升模型辨識效果的方法。

讓我們繼續看下去...


上一篇
【第24天】部署API服務-GCP架設VM(二)
下一篇
【第26天】探討與改善-增加訓練樣本(一)
系列文
手寫中文字之影像辨識31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言